import torch
from collections import defaultdict


class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes:
        val: Most recent value passed to :meth:`update`.
        sum: Weighted sum of all recorded values (``val * n`` per update).
        count: Total weight (sum of ``n``) recorded so far.
        avg: ``sum / count``; stays 0 until a non-zero weight is recorded.
        min: Smallest value the running *average* has taken (not the raw
            values). Starts at the sentinel ``1e+8``.
        max: Largest value the running *average* has taken. Starts at the
            sentinel ``-1e+8``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to their initial state."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        # Finite sentinels: an average that never comes inside +/-1e8 will
        # never replace them.
        self.min = 1e+8
        self.max = -1e+8

    def update(self, val, n=1):
        """Record ``val`` with weight ``n``.

        ``Logger`` passes the batch size as ``n``, so ``avg`` becomes a
        sample-weighted mean of per-batch values.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # With no accumulated weight yet (e.g. a first update with n=0) the
        # average is undefined; skip it instead of raising ZeroDivisionError.
        if self.count:
            self.avg = self.sum / self.count
            self.min = min(self.avg, self.min)
            self.max = max(self.avg, self.max)

    def __str__(self):
        # The previous "{self.avg:.4f}".format(self.avg) raised
        # KeyError('self'): str.format looks up {self...} as a keyword arg.
        return f"{self.avg:.4f}"


class Logger:
    """Accumulate named scalar metrics as running, batch-weighted averages.

    Each keyword passed to :meth:`update` gets its own ``AverageMeter``.
    Meters are always reachable through :attr:`meters`. For backward
    compatibility they are also exposed as attributes (``logger.loss``)
    whenever the name does not shadow an existing Logger attribute/method.
    """

    def __init__(self):
        # Metric names in first-seen order; drives iteration order in out().
        self.keys = list()
        # name -> AverageMeter. Keeping meters in a dict (not only as
        # attributes) lets metrics named e.g. 'keys' or 'update' work.
        self.meters = dict()

    def update(self, bsz, **kwargs):
        """Record one value per metric, each weighted by ``bsz``.

        Args:
            bsz: Weight passed to ``AverageMeter.update`` as ``n``.
            **kwargs: metric name -> value. A ``torch.Tensor`` is converted
                with ``.item()`` (so it must hold a single element).

        Raises:
            TypeError: if a value is not an int/float after conversion.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            # Explicit check instead of `assert`, which is stripped under -O.
            if not isinstance(v, (float, int)):
                raise TypeError(
                    f"Logger.update: value for {k!r} must be a number or a "
                    f"single-element tensor, got {type(v).__name__}"
                )
            meter = self.meters.get(k)
            if meter is None:
                meter = self.meters[k] = AverageMeter()
                self.keys.append(k)
                # Attribute alias for existing callers; never clobber a real
                # Logger member such as `keys` or `update`.
                if not hasattr(self, k):
                    setattr(self, k, meter)
            meter.update(v, bsz)

    def reset(self):
        """Reset every meter's statistics; the set of known metrics is kept."""
        for meter in self.meters.values():
            meter.reset()

    def out(self):
        """Return ``(dict, str)`` snapshots of all metrics.

        The dict maps ``'log/<name>.val'`` / ``'log/<name>.avg'`` to the latest
        and averaged values. The string is ``'<name> <val> (<avg>)\\t'``
        repeated per metric, with 3 decimals each.
        """
        d = dict()
        parts = []
        for k in self.keys:
            meter = self.meters[k]
            v, a = meter.val, meter.avg
            d[f'log/{k}.val'] = v
            d[f'log/{k}.avg'] = a
            parts.append(f'{k} {v:.3f} ({a:.3f})\t')
        return d, ''.join(parts)



